#include "ssd_priorbox.h"

/*
 * Clamp a float into the closed interval [down, up].
 *
 * C has no function overloading, so the integer flavour lives in
 * coordinate_clip_int().
 *
 * down  - lower bound
 * up    - upper bound
 * input - value to clamp
 *
 * Returns input when it lies inside the interval, otherwise the bound it
 * crossed. The lower bound is applied first and the upper bound second, so
 * if down > up the result is always up.
 */
float coordinate_clip(float down, float up, float input)
{
	float clipped = down;

	/* Keep the caller's value only when it is not below the lower bound. */
	if (input >= down) {
		clipped = input;
	}

	/* Anything that is not <= up collapses onto the upper bound. */
	if (clipped <= up) {
		return clipped;
	}
	return up;
}

/*
 * Clamp an int into the closed interval [down, up].
 *
 * Per the original author's note, coordinates must not leave the valid
 * range, otherwise drawing the box reports an error.
 *
 * down  - lower bound
 * up    - upper bound
 * input - value to clamp
 *
 * Returns input when it lies inside the interval, otherwise the bound it
 * crossed. The lower bound is applied first and the upper bound second, so
 * if down > up the result is always up.
 */
int coordinate_clip_int(int down, int up, int input)
{
	int clipped = down;

	/* Keep the caller's value only when it is not below the lower bound. */
	if (input >= down) {
		clipped = input;
	}

	/* Anything above the upper bound collapses onto it. */
	if (clipped <= up) {
		return clipped;
	}
	return up;
}

/*
 * Release a singly linked list of Node.
 *
 * predbbox - first node to release; may be NULL, in which case nothing
 *            happens.
 *
 * Every node reachable through ->next, including predbbox itself, is passed
 * to free(). The caller's own pointer is not reset and must not be used
 * afterwards.
 */
void free_bbox(Node *predbbox){

	for (Node *doomed = predbbox; doomed != NULL; ){
		/* Grab the successor first: doomed is invalid once it is freed. */
		Node *following = doomed->next;
		free(doomed);
		doomed = following;
	}
}

/*
 * Sort the detections in descending order of confidence.
 *
 * predbbox - head node of the list. The head itself is skipped; sorting
 *            starts at predbbox->next. A NULL head is a no-op.
 *
 * The list links are left untouched: nodes stay where they are and only
 * their payload (bbox and confidence) is exchanged. This is a simple
 * O(n^2) exchange sort. A swap is also performed when two confidences are
 * equal, so the relative order of ties is not preserved.
 */
void pred_bbox_sort(Node *predbbox){

	if (predbbox == NULL){
		return;
	}

	for (Node *anchor = predbbox->next; anchor != NULL; anchor = anchor->next){
		for (Node *probe = anchor->next; probe != NULL; probe = probe->next){
			/* Re-read on every pass: a swap below changes anchor's value. */
			int anchor_confidence = anchor->confidence;
			int probe_confidence = probe->confidence;

			if (anchor_confidence > probe_confidence){
				continue; /* already in descending order for this pair */
			}

			/* probe is at least as confident: move its payload forward. */
			PredBbox saved_bbox = anchor->bbox;
			anchor->bbox = probe->bbox;
			anchor->confidence = probe_confidence;
			probe->bbox = saved_bbox;
			probe->confidence = anchor_confidence;
		}
	}
}

/*
 * Greedy non-maximum suppression on a list of detections.
 *
 * predbbox - head node of the list. The head itself is skipped; processing
 *            starts at predbbox->next. A NULL head is a no-op.
 * nms      - IoU threshold; a later box whose IoU with an earlier kept box
 *            is strictly greater than this value is removed.
 *
 * The list is expected to be sorted by descending confidence already (see
 * pred_bbox_sort), so each surviving node suppresses the overlapping nodes
 * that follow it. Suppressed nodes are unlinked and released with free().
 * Pairs whose union area is not positive (degenerate boxes) are never
 * suppressed, which also avoids a 0/0 division.
 */
void pred_bbox_nms(Node *predbbox, float nms){

	if (predbbox == NULL){
		return;
	}

	for (Node *keeper = predbbox->next; keeper != NULL; keeper = keeper->next){
		float keeper_area = (keeper->bbox.bbox_x_right - keeper->bbox.bbox_x_left) * (keeper->bbox.bbox_y_right - keeper->bbox.bbox_y_left);
		Node *prev = keeper;               /* node just before candidate */
		Node *candidate = keeper->next;

		while (candidate != NULL){
			/* Intersection rectangle of keeper and candidate. */
			float overlap_x_left = ssd_MAX(keeper->bbox.bbox_x_left, candidate->bbox.bbox_x_left);
			float overlap_y_left = ssd_MAX(keeper->bbox.bbox_y_left, candidate->bbox.bbox_y_left);
			float overlap_x_right = ssd_MIN(keeper->bbox.bbox_x_right, candidate->bbox.bbox_x_right);
			float overlap_y_right = ssd_MIN(keeper->bbox.bbox_y_right, candidate->bbox.bbox_y_right);

			/* Disjoint boxes give a negative extent; clamp it to zero. */
			float w = ssd_MAX((overlap_x_right - overlap_x_left), 0.0F);
			float h = ssd_MAX((overlap_y_right - overlap_y_left), 0.0F);
			float overlap_area = w * h;

			float candidate_area = (candidate->bbox.bbox_x_right - candidate->bbox.bbox_x_left) * (candidate->bbox.bbox_y_right - candidate->bbox.bbox_y_left);
			float union_area = keeper_area + candidate_area - overlap_area;

			/* IoU = intersection / union, evaluated only for a positive union. */
			if (union_area > 0.0f && (overlap_area / union_area) > nms){
				/* Unlink and release; prev stays put because its successor changed. */
				prev->next = candidate->next;
				free(candidate);
				candidate = prev->next;
			}else{
				prev = candidate;
				candidate = candidate->next;
			}
		}
	}
}

/*
 * Generate the SSD prior (anchor) boxes into a caller-supplied array.
 *
 * priorbox - output array, written sequentially from index 0. The caller
 *            must provide enough room for every box generated below; no
 *            bounds check is possible here because the capacity is not
 *            passed in.
 *
 * For every feature-map level and every cell (row-major: j is the row,
 * k the column) the function stores the normalized centre (cx, cy) and a
 * set of (w, h) pairs, each clipped to [0, 1]:
 *   - levels 0..3: one box per aT_1[0 .. a1_LONGTH-2] using scale s_k, plus
 *     one extra box using the geometric-mean scale s_k_1 with aT_1[4];
 *   - remaining levels: one box per aT_2 entry using scale s_k.
 * Width is scale * sqrt(ratio) and height is scale / sqrt(ratio).
 *
 * Returns 0 on success, -1 when priorbox is NULL.
 */
int priorbox_init(PriorBox *priorbox)
{
	if (priorbox == NULL){
		return -1;
	}

	float input_image_size = INPUT_SIZE;
	/* Basic SSD configuration. Per the original author, some values were
	 * adjusted for the target scenario, so they may differ from the paper. */
	float feature_map_size[LONGTH] = {38.0f, 19.0f, 10.0f, 5.0f, 3.0f, 1.0f};

	float min_sk[LONGTH] = {30.0f, 60.0f, 111.0f, 162.0f, 213.0f, 264.0f};

	float max_sk[LONGTH] = {60.0f, 111.0f, 162.0f, 213.0f, 264.0f, 315.0f};

	float steps[LONGTH] = {8.0f, 16.0f, 32.0f, 64.0f, 100.0f, 300.0f}; /* down-sampling factor per level */

	float aT_1[a1_LONGTH] = {0.2f, 0.5f, 0.8f, 1.1f, 0.6f}; /* aspect ratios, levels 0..3 */

	float aT_2[a2_LONGTH] = {0.3f, 0.6f, 0.9f, 1.2f}; /* aspect ratios, remaining levels */

	unsigned int index = 0; /* next free slot in priorbox */

	for(int i = 0; i < LONGTH; i++){
		float f_k = input_image_size / steps[i];              /* divisor that normalizes cell centres */
		float s_k = min_sk[i] / input_image_size;             /* normalized base scale */
		float s_k_1 = sqrt(s_k * (max_sk[i] / input_image_size)); /* geometric mean of min/max scale */
		int cells = (int)feature_map_size[i];

		for(int j = 0; j < cells; j++){
			for(int k = 0; k < cells; k++){
				/* Cell centre; the 0.5 literal keeps the division in double
				 * precision before the result is narrowed to float. */
				float x = ((float)k + 0.5) / f_k;
				float y = ((float)j + 0.5) / f_k;
				/* The centre is shared by every box of this cell. */
				float cx = coordinate_clip(0.0f, 1.0f, x);
				float cy = coordinate_clip(0.0f, 1.0f, y);

				if (i < 4){
					for(int m = 0; m < (a1_LONGTH - 1); m++){
						double ratio_root = sqrt(aT_1[m]);

						priorbox[index].anchor_cx = cx;
						priorbox[index].anchor_cy = cy;
						priorbox[index].anchor_w = coordinate_clip(0.0f, 1.0f, s_k * ratio_root);
						priorbox[index].anchor_h = coordinate_clip(0.0f, 1.0f, s_k / ratio_root);
						index += 1;
					}

					/* Extra box at the intermediate scale, using the last ratio of aT_1. */
					double extra_root = sqrt(aT_1[4]);

					priorbox[index].anchor_cx = cx;
					priorbox[index].anchor_cy = cy;
					priorbox[index].anchor_w = coordinate_clip(0.0f, 1.0f, s_k_1 * extra_root);
					priorbox[index].anchor_h = coordinate_clip(0.0f, 1.0f, s_k_1 / extra_root);
					index += 1;
				}else{
					for(int m = 0; m < a2_LONGTH; m++){
						double ratio_root = sqrt(aT_2[m]);

						priorbox[index].anchor_cx = cx;
						priorbox[index].anchor_cy = cy;
						priorbox[index].anchor_w = coordinate_clip(0.0f, 1.0f, s_k * ratio_root);
						priorbox[index].anchor_h = coordinate_clip(0.0f, 1.0f, s_k / ratio_root);
						index += 1;
					}
				}
			}
		}
	}

	return 0;
}

/*
 * Decode the NNIE detection outputs into a linked list of candidate boxes.
 *
 * pstNnieParam       - NNIE parameters; segment 0, output blob 0 is read as
 *                      the box regression and output blob 1 as the class
 *                      scores.
 * priorbox           - prior boxes, indexed 0 .. PriorBox_Num-1.
 * predbbox_head_node - head node of the result list. New nodes are appended
 *                      after it; head->next is only written when at least
 *                      one box passes the threshold.
 * cls_threshold_     - score threshold as a float; it is scaled by 4096 to
 *                      match the integer scores read from the blob.
 *
 * Each appended Node is heap-allocated and owned by the list (release it
 * with free_bbox). If any pointer argument or blob address is NULL the
 * function returns without touching the list. If an allocation fails the
 * function stops early; the boxes appended so far remain a valid,
 * NULL-terminated list.
 */
void get_pred_bbox(SAMPLE_SVP_NNIE_PARAM_S *pstNnieParam, PriorBox *priorbox, Node *predbbox_head_node, float cls_threshold_) {
	if (pstNnieParam == NULL || priorbox == NULL || predbbox_head_node == NULL) {
		return;
	}

	/* Blob values are integers scaled by 4096 (see the divisions below). */
	int cls_threshold = (int)(cls_threshold_ * 4096.0f);

	/* Rows per prior in each blob (original notes: 4 for bbox, 2 for class). */
	HI_S32 reg_h = pstNnieParam->astSegData[0].astDst[0].unShape.stWhc.u32Height;
	HI_S32 cls_h = pstNnieParam->astSegData[0].astDst[1].unShape.stWhc.u32Height;

	HI_S32* reg = (HI_S32* )((HI_U8* )pstNnieParam->astSegData[0].astDst[0].u64VirAddr); /* bbox output base address */
	HI_S32* cls = (HI_S32* )((HI_U8* )pstNnieParam->astSegData[0].astDst[1].u64VirAddr); /* confidence output base address */
	if (reg == NULL || cls == NULL) {
		return;
	}

	/* u32Stride is in bytes; dividing by the element size gives the row
	 * stride in elements (original note: 16 bytes -> 4 elements).
	 * NOTE(review): the class blob's stride is also used to index the
	 * regression blob, which is only right while both strides are equal.
	 * Confirm against astDst[0].u32Stride. */
	int stride = pstNnieParam->astSegData[0].astDst[1].u32Stride / sizeof(int);
	/* Skip one row so that the second score of each prior is read; per the
	 * original author this is the person-class confidence. */
	HI_S32* start = cls + stride;

	float anchor_variance[2] = {0.1f, 0.2f}; /* prior-box variances: [0] centre, [1] size */
	Node *last_node = predbbox_head_node;    /* current tail of the result list */

	for(int i = 0; i < PriorBox_Num; i++){
		int confidence = *(start + i * stride * cls_h);
		if (confidence <= cls_threshold){
			continue;
		}

		float anchor_cx = priorbox[i].anchor_cx;
		float anchor_cy = priorbox[i].anchor_cy;
		float anchor_w = priorbox[i].anchor_w;
		float anchor_h = priorbox[i].anchor_h;

		/* Regression offsets: four consecutive rows per prior. The layout is
		 * platform/model specific (original author's warning). */
		HI_S32 *offsets = reg + i * stride * reg_h;
		float locate_cx = *(offsets)/4096.0f;
		float locate_cy = *(offsets + stride)/4096.0f;
		float locate_w = *(offsets + stride * 2)/4096.0f;
		float locate_h = *(offsets + stride * 3)/4096.0f;

		/* calloc zero-fills, matching the previous malloc + memset. */
		Node *per_result = calloc(1, sizeof *per_result);
		if (per_result == NULL){
			return; /* out of memory: keep what has been collected so far */
		}
		last_node->next = per_result; /* link the new node at the tail */
		last_node = per_result;

		/* Decode centre/size from the prior and the offsets. */
		float cx = anchor_cx + locate_cx * anchor_variance[0] * anchor_w;
		float cy = anchor_cy + locate_cy * anchor_variance[0] * anchor_h;
		float w = anchor_w * (float)(exp(locate_w * anchor_variance[1]));
		float h = anchor_h * (float)(exp(locate_h * anchor_variance[1]));

		/* Store the result as corner coordinates. */
		per_result->bbox.bbox_x_left = cx - w / 2.0f;
		per_result->bbox.bbox_y_left = cy - h / 2.0f;
		per_result->bbox.bbox_x_right = cx + w / 2.0f;
		per_result->bbox.bbox_y_right = cy + h / 2.0f;
		per_result->confidence = confidence;
	}
}